{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[(0, 'alt.atheism'), (1, 'comp.graphics'), (2, 'comp.os.ms-windows.misc'), (3, 'comp.sys.ibm.pc.hardware'), (4, 'comp.sys.mac.hardware'), (5, 'comp.windows.x'), (6, 'misc.forsale'), (7, 'rec.autos'), (8, 'rec.motorcycles'), (9, 'rec.sport.baseball'), (10, 'rec.sport.hockey'), (11, 'sci.crypt'), (12, 'sci.electronics'), (13, 'sci.med'), (14, 'sci.space'), (15, 'soc.religion.christian'), (16, 'talk.politics.guns'), (17, 'talk.politics.mideast'), (18, 'talk.politics.misc'), (19, 'talk.religion.misc')]\n"
     ]
    }
   ],
   "source": [
    "from sklearn.datasets import fetch_20newsgroups\n",
    "# Load the 'all' subset, asking the loader to strip headers, footers and quotes.\n",
    "data20 = fetch_20newsgroups(\n",
    "    subset='all',\n",
    "    shuffle=True,\n",
    "    remove=('headers', 'footers', 'quotes'),\n",
    ")\n",
    "# Show which integer id is paired with each newsgroup name.\n",
    "print(list(enumerate(data20.target_names)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "# Resolve the label ids from the category names instead of hard-coding 9 / 10,\n",
    "# so the selection stays correct even if the ordering of target_names differs.\n",
    "baseball_id = data20.target_names.index('rec.sport.baseball')\n",
    "hockey_id = data20.target_names.index('rec.sport.hockey')\n",
    "\n",
    "# Positions (into data20.data / data20.target) of the posts in each group.\n",
    "baseball = np.flatnonzero(data20.target == baseball_id)\n",
    "hockey = np.flatnonzero(data20.target == hockey_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Baseball posts first, then hockey; `strings` and `target` stay index-aligned.\n",
    "selected = list(baseball) + list(hockey)\n",
    "strings = [data20.data[i] for i in selected]\n",
    "# The binary label follows directly from the group an index came from\n",
    "# (0 = baseball, 1 = hockey), so no hard-coded newsgroup id is needed here.\n",
    "target = [0] * len(baseball) + [1] * len(hockey)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from Transparency.preprocess.vectorizer import cleaner\n",
    "import re\n",
    "\n",
    "def cleaner_20(text) :\n",
    "    \"\"\"Clean one raw post and normalise its spacing.\n",
    "\n",
    "    Applies the imported `cleaner` first, then replaces every run of non-word\n",
    "    characters with a space-padded copy of the run's last character (for a\n",
    "    repeated group, `re` keeps only the final repetition in group 1), collapses\n",
    "    whitespace runs to a single space and strips both ends.\n",
    "\n",
    "    NOTE(review): a run such as ', ' ends in a space, so the comma is dropped\n",
    "    rather than kept as its own token - confirm this is intended before\n",
    "    relying on punctuation surviving this step.\n",
    "    \"\"\"\n",
    "    cleaned = cleaner(text)\n",
    "    padded = re.sub(r'(\\W)+', r' \\1 ', cleaned)\n",
    "    return re.sub(r'\\s+', ' ', padded).strip()\n",
    "\n",
    "# Clean each post alongside its label, then keep only the pairs whose cleaned\n",
    "# text is non-empty; unzip back into two aligned tuples.\n",
    "cleaned_pairs = [(cleaner_20(s), t) for s, t in zip(strings, target)]\n",
    "strings, target = zip(*[(s, t) for s, t in cleaned_pairs if s])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "SPLIT_SEED = 13478\n",
    "HELDOUT_FRACTION = 0.2\n",
    "\n",
    "# First hold out HELDOUT_FRACTION of all examples as the test set,\n",
    "# stratified on the label.\n",
    "all_idx = range(len(strings))\n",
    "train_idx, test_idx = train_test_split(\n",
    "    all_idx, stratify=target, test_size=HELDOUT_FRACTION, random_state=SPLIT_SEED\n",
    ")\n",
    "# Then hold out the same fraction of the remaining examples as the dev set.\n",
    "remaining_labels = [target[i] for i in train_idx]\n",
    "train_idx, dev_idx = train_test_split(\n",
    "    train_idx, stratify=remaining_labels, test_size=HELDOUT_FRACTION, random_state=SPLIT_SEED\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index lists per split, in the key order used for `texts` and `labels`.\n",
    "split_indices = { 'train' : train_idx, 'test' : test_idx, 'dev' : dev_idx }\n",
    "\n",
    "# Materialise the text and label of every example in each split.\n",
    "texts = { name : [strings[i] for i in idx] for name, idx in split_indices.items() }\n",
    "labels = { name : [target[i] for i in idx] for name, idx in split_indices.items() }\n",
    "\n",
    "# Conventional aliases for the individual splits.\n",
    "X_train, X_dev, X_test = texts['train'], texts['dev'], texts['test']\n",
    "y_train, y_dev, y_test = labels['train'], labels['dev'], labels['test']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Flatten the three splits into aligned columns, in train -> test -> dev order.\n",
    "split_order = ['train', 'test', 'dev']\n",
    "df_texts = [text for key in split_order for text in texts[key]]\n",
    "df_labels = [label for key in split_order for label in labels[key]]\n",
    "# One split tag per example, so each row records which split it belongs to.\n",
    "df_exp_splits = [key for key in split_order for _ in texts[key]]\n",
    "\n",
    "df = pd.DataFrame({'text' : df_texts, 'label' : df_labels, 'exp_split' : df_exp_splits})\n",
    "df.to_csv('20News_sports_dataset.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Vocabulary size :  6515\n",
      "Found 5771 words in model out of 6515\n"
     ]
    }
   ],
   "source": [
    "# Run ../preprocess_data_BC.py on the 20News_sports_dataset.csv file written by\n",
    "# this notebook. Presumably --output_file is where the script stores its result\n",
    "# and --min_df is a minimum document frequency for vocabulary words - confirm\n",
    "# against the script's argument parser.\n",
    "%run \"../preprocess_data_BC.py\" --data_file 20News_sports_dataset.csv \\\n",
    "--output_file ./vec_20news_sports.p --word_vectors_type fasttext.simple.300d --min_df 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
